Appearance
用二维 tensor 举例最直观。
假设 dec_out_list 里有 3 个二维 tensor:
python
y0 = torch.tensor([
[1, 2],
[3, 4],
])
y1 = torch.tensor([
[10, 20],
[30, 40],
])
y2 = torch.tensor([
[100, 200],
[300, 400],
])每个 shape 都是:
python
[2, 2]现在执行:
python
torch.stack([y0, y1, y2], dim=-1)因为 dim=-1 是在最后新增一个维度,所以结果 shape 变成:
python
[2, 2, 3]结果可以理解为:
python
stacked = [
[
[1, 10, 100],
[2, 20, 200],
],
[
[3, 30, 300],
[4, 40, 400],
],
]也就是说,原来同一个位置的值被放到新维度里:
python
stacked[0, 0, :] = [y0[0,0], y1[0,0], y2[0,0]]
= [1, 10, 100]
stacked[0, 1, :] = [y0[0,1], y1[0,1], y2[0,1]]
= [2, 20, 200]然后执行:
python
stacked.sum(-1)就是对最后一维求和:
python
[
[1 + 10 + 100, 2 + 20 + 200],
[3 + 30 + 300, 4 + 40 + 400],
]结果是:
python
tensor([
[111, 222],
[333, 444],
])所以:
python
torch.stack(dec_out_list, dim=-1).sum(-1)等价于:
python
y0 + y1 + y2只是 stack + sum 可以处理任意数量的尺度预测。